{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to Train a Model on MNIST with FiftyOne and Torch\n",
    "\n",
    "This recipe demonstrates how to train a PyTorch model on the **MNIST** dataset using `FiftyOneTorchDataset`. This is useful when you want to build and evaluate models in PyTorch while managing your data pipeline directly from FiftyOne. Specifically, it covers:\n",
    "\n",
    "* Loading the MNIST dataset from the [Dataset Zoo](https://voxel51.com/docs/fiftyone/user_guide/dataset_zoo/index.html)\n",
    "* Creating train/validation/test splits with FiftyOne’s tagging and random splitting utilities\n",
    "* Building a subset of the dataset for faster experimentation\n",
    "* Running a simple training loop via an external script (`mnist_training.py`)\n",
    "* Saving model weights for later evaluation or reuse\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "If you haven't already, install FiftyOne:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install fiftyone"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This recipe uses [PyTorch](https://pytorch.org/) to define and train the model. If you don't already have them, install `torch` and `torchvision`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torch torchvision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fiftyone as fo\n",
    "import fiftyone.zoo as foz\n",
    "import fiftyone.utils.random as four"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import torchvision.transforms.v2 as transforms\n",
    "from torchvision import tv_tensors\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as plt_patches\n",
    "from PIL import Image\n",
    "import urllib.request"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run this recipe, you’ll need the `mnist_training.py` script, which contains a simple PyTorch training loop. The following cells download it, along with a companion `utils.py` module, into your working directory so that it can be imported directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://cdn.voxel51.com/tutorials_torch_dataset_examples/notebook_simple_training_example/mnist_training.py\"\n",
    "urllib.request.urlretrieve(url, \"mnist_training.py\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://cdn.voxel51.com/tutorials_torch_dataset_examples/notebook_the_cache_field_names_argument/utils.py\"\n",
    "urllib.request.urlretrieve(url, \"utils.py\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mnist_training"
   ]
  },
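  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the 'forkserver' start method for worker processes and preload\n",
    "# torch and fiftyone in the server process so workers don't re-import them\n",
    "torch.multiprocessing.set_start_method('forkserver')\n",
    "torch.multiprocessing.set_forkserver_preload(['torch', 'fiftyone'])"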
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic Training Example on MNIST"
   ]
  },
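  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's walk through an actual training run with `FiftyOneTorchDataset`. First, load MNIST from the Dataset Zoo and mark the dataset as persistent so that it remains in the database across sessions."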
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split 'train' already downloaded\n",
      "Split 'test' already downloaded\n",
      "Loading existing dataset 'mnist'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use\n"
     ]
    }
   ],
   "source": [
    "mnist = foz.load_zoo_dataset(\"mnist\")\n",
    "mnist.persistent = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Session launched. Run `session.show()` to open the App in a cell output.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset:          mnist\n",
       "Media type:       image\n",
       "Num samples:      70000\n",
       "Selected samples: 0\n",
       "Selected labels:  0\n",
       "Session URL:      http://localhost:5151/"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fo.launch_app(mnist, auto=False)"
   ]
  },
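  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose we want to hold out a random subset of the training split as a validation set. FiftyOne's `random_split()` utility does this by tagging samples: here, 90% of the samples not tagged `test` are tagged `train` and the remaining 10% are tagged `validation`."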
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train': 54000, 'validation': 6000, 'test': 10000}\n"
     ]
    }
   ],
   "source": [
    "# remove existing 'train' or 'validation' tags if they exist\n",
    "mnist.untag_samples(['train', 'validation'])\n",
    "\n",
    "# create a random split, just on the samples not tagged 'test'\n",
    "not_test = mnist.match_tags('test', bool=False)\n",
    "four.random_split(not_test, {'train' : 0.9, 'validation' : 0.1})\n",
    "print(mnist.count_sample_tags())"
   ]
  },
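  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optionally build a smaller subset for faster experimentation:\n",
    "# 1,000 random training samples plus the full validation and test splits\n",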
    "samples = []\n",
    "samples += mnist.match_tags('train').take(1000).values('id')\n",
    "for tag in ['test', 'validation']:\n",
    "    samples += mnist.match_tags(tag).values('id')\n",
    "\n",
    "subset = mnist.select(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Train Loss =   3.999074: 100%|██████████| 63/63 [00:01<00:00, 58.45it/s]\n",
      "Average Validation Loss =   2.811698: 100%|██████████| 375/375 [00:02<00:00, 149.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New best lost achieved : 2.801392190893491. Saving model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Train Loss =   1.072026: 100%|██████████| 63/63 [00:00<00:00, 119.78it/s]\n",
      "Average Validation Loss =   0.396746: 100%|██████████| 375/375 [00:01<00:00, 215.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New best lost achieved : 0.39641891201337176. Saving model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Train Loss =   0.148484: 100%|██████████| 63/63 [00:00<00:00, 120.53it/s]\n",
      "Average Validation Loss =   0.319500: 100%|██████████| 375/375 [00:01<00:00, 211.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New best lost achieved : 0.3149221637323499. Saving model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Train Loss =   0.627752: 100%|██████████| 63/63 [00:00<00:00, 97.89it/s] \n",
      "Average Validation Loss =   0.304854: 100%|██████████| 375/375 [00:01<00:00, 207.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New best lost achieved : 0.2977131818582614. Saving model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Train Loss =   0.204026: 100%|██████████| 63/63 [00:00<00:00, 119.48it/s]\n",
      "Average Validation Loss =   0.210062: 100%|██████████| 375/375 [00:01<00:00, 214.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New best lost achieved : 0.2064167803612848. Saving model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Train Loss =   0.070824: 100%|██████████| 63/63 [00:00<00:00, 106.55it/s]\n",
      "Average Validation Loss =   1.467735: 100%|██████████| 375/375 [00:02<00:00, 173.34it/s]\n",
      "Average Train Loss =   0.509837: 100%|██████████| 63/63 [00:00<00:00, 112.51it/s]\n",
      "Average Validation Loss =   0.387830: 100%|██████████| 375/375 [00:02<00:00, 163.92it/s]\n",
      "Average Train Loss =   0.236021: 100%|██████████| 63/63 [00:00<00:00, 116.83it/s]\n",
      "Average Validation Loss =   0.287110: 100%|██████████| 375/375 [00:01<00:00, 211.45it/s]\n",
      "Average Train Loss =   0.047093: 100%|██████████| 63/63 [00:00<00:00, 99.11it/s] \n",
      "Average Validation Loss =   0.156705: 100%|██████████| 375/375 [00:01<00:00, 213.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New best lost achieved : 0.14917240004179377. Saving model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Train Loss =   0.009842: 100%|██████████| 63/63 [00:00<00:00, 97.05it/s] \n",
      "Average Validation Loss =   0.138089: 100%|██████████| 375/375 [00:01<00:00, 211.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New best lost achieved : 0.13520573990046977. Saving model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Validation Loss =   0.113355: 100%|██████████| 625/625 [00:10<00:00, 61.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final Test Results:\n",
      "Loss = 0.11413920720983296\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    0 - zero       0.98      0.97      0.98       980\n",
      "     1 - one       0.98      0.99      0.99      1135\n",
      "     2 - two       0.96      0.97      0.96      1032\n",
      "   3 - three       0.95      0.97      0.96      1010\n",
      "    4 - four       0.96      0.97      0.96       982\n",
      "    5 - five       0.95      0.96      0.95       892\n",
      "     6 - six       0.96      0.97      0.96       958\n",
      "   7 - seven       0.97      0.93      0.95      1028\n",
      "   8 - eight       0.98      0.94      0.96       974\n",
      "    9 - nine       0.95      0.96      0.96      1009\n",
      "\n",
      "    accuracy                           0.96     10000\n",
      "   macro avg       0.96      0.96      0.96     10000\n",
      "weighted avg       0.96      0.96      0.96     10000\n",
      "\n"
     ]
    }
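   ],
   "source": [
    "# train on the GPU if one is available, otherwise fall back to the CPU\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# placeholder: replace with the location where model weights should be saved\n",
    "path_to_save_weights = '/path/to/save/weights'\n",
    "mnist_training.main(subset, 10, 10, device, path_to_save_weights)"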
   ]
  },
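  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This recipe showed how to train a PyTorch model on MNIST using `FiftyOneTorchDataset`: creating train/validation splits with tags, selecting a subset for faster experimentation, and running a simple training loop that saves the best model weights."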
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch-dataset",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
